# ---------------------------------------------------------------------
# Copyright (c) 2025 Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
# ---------------------------------------------------------------------

from __future__ import annotations

import functools
from typing import cast

import torch
from mobile_sam import sam_model_registry
from mobile_sam.modeling.mask_decoder import MLP as SAMMaskDecoderMLP
from mobile_sam.modeling.sam import Sam
from mobile_sam.modeling.transformer import TwoWayAttentionBlock, TwoWayTransformer
from mobile_sam.utils.onnx import SamOnnxModel

from qai_hub_models.models._shared.sam.model_patches import (
    Conv2DInplaceLinearSAMMaskDecoderMLP,
    Conv2DInplaceLinearSAMTransformerMLPBlock,
    SplitHeadSAMDecoderAttention,
    sam_decoder_predict_masks,
)
from qai_hub_models.utils.asset_loaders import CachedWebModelAsset
from qai_hub_models.utils.base_model import BaseModel, CollectionModel
from qai_hub_models.utils.input_spec import InputSpec

# Identifier used to address this model's assets: the second-to-last
# component of this module's dotted import name.
MODEL_ID = __name__.rsplit(".", 2)[-2]

# Supported SAM variants, mapped to the checkpoint file name in the asset store.
SMALL_MODEL_TYPE = "vit_t"
DEFAULT_MODEL_TYPE = SMALL_MODEL_TYPE
MODEL_REGISTRY = {SMALL_MODEL_TYPE: "mobile_sam.pt"}

# Asset-store handle for the small variant's checkpoint.
MODEL_ASSET_VERSION = 1
MODEL_ADDRESS = CachedWebModelAsset.from_asset_store(
    MODEL_ID,
    MODEL_ASSET_VERSION,
    MODEL_REGISTRY[SMALL_MODEL_TYPE],
)


class MobileSAMEncoder(BaseModel):
    """
    Image encoder component of MobileSAM.

    Wraps a ``Sam`` network and exposes only its preprocessing step followed
    by its image encoder.
    """

    def __init__(self, sam: Sam):
        """
        Parameters
        ----------
            sam: Sam
                Network whose ``preprocess`` and ``image_encoder`` are run by forward().
        """
        super().__init__()
        self.sam = sam

    def forward(self, image: torch.Tensor):
        """
        Preprocess an image and compute its embeddings.

        Parameters
        ----------
            image: torch.Tensor of shape [batch_size, 3, height, width]
                Input image, laid out as declared by get_input_spec().
                NOTE(review): the expected value range is whatever
                ``sam.preprocess`` normalizes with the network's
                pixel_mean / pixel_std; networks produced by
                MobileSAMLoader.load() have those rescaled for [0, 1] input.

        Returns
        -------
            image_embeddings: the output of ``sam.image_encoder``.
        """
        x = self.sam.preprocess(image)
        return self.sam.image_encoder(x)

    @staticmethod
    def get_input_spec(
        batch_size: int = 1,
        encoder_img_height: int = 1024,  # self.sam.image_encoder.img_size
        encoder_img_width: int = 1024,  # self.sam.image_encoder.img_size
    ) -> InputSpec:
        """
        Return the input specification (name -> (shape, type)) of the encoder.

        The defaults describe a 1024 x 1024 image; instances report the size
        of their own image encoder via _get_input_spec_for_instance().
        """
        return {
            "image": ((batch_size, 3, encoder_img_height, encoder_img_width), "float32")
        }

    def _get_input_spec_for_instance(
        self,
        batch_size: int = 1,
    ) -> InputSpec:
        """Return get_input_spec() sized from this instance's image encoder (square input)."""
        img_size = self.sam.image_encoder.img_size
        return self.__class__.get_input_spec(batch_size, img_size, img_size)

    @staticmethod
    def get_channel_last_inputs() -> list[str]:
        """Return the names of the inputs declared as channel-last."""
        return ["image"]

    @staticmethod
    def get_output_names() -> list[str]:
        """Return the names of the outputs of forward()."""
        return ["image_embeddings"]

    @staticmethod
    def get_channel_last_outputs() -> list[str]:
        """Return the names of the outputs declared as channel-last."""
        return ["image_embeddings"]

    @classmethod
    def from_pretrained(cls, model_type: str = DEFAULT_MODEL_TYPE) -> MobileSAMEncoder:
        """
        Load the pretrained network and return its encoder component.

        The network is built through MobileSAMLoader.load(), the same path
        used by MobileSAMDecoder.from_pretrained() and
        MobileSAM.from_pretrained(). This matters because load() applies the
        QNN compatibility patch, which rescales pixel_mean / pixel_std for
        [0, 1] float input. Wrapping the unpatched network instead would give
        an encoder that expects a different input range than the one inside
        the MobileSAM collection.

        Parameters
        ----------
            model_type: str
                Key into MODEL_REGISTRY selecting the checkpoint to load.
        """
        return MobileSAMLoader.load(model_type)[1]


class MobileSAMDecoder(BaseModel):
    """
    Adapted from segment_anything.utils.onnx.SamOnnxModel with modifications.

    This removes output mask resizing. Because this requires a dynamic shape to accomplish
    in the network, it's better to do this as a postprocessing step rather than in the inference
    framework itself.
    """

    def __init__(self, sam: Sam, return_single_mask: bool):
        """
        Parameters
        ----------
            sam: Sam
                Network whose prompt encoder and mask decoder are used.
            return_single_mask: bool
                If set, forward() reduces its predictions with
                SamOnnxModel.select_masks before returning them.
        """
        super().__init__(sam)
        # The base initializer is expected to expose `sam` as `self.model`
        # (it is read on the next lines); this annotation only narrows the type.
        self.model: Sam
        self.embed_size = self.model.prompt_encoder.image_embedding_size
        # NOTE(review): `model` and `img_size` are presumably what the borrowed
        # SamOnnxModel helpers read when forward() passes this object as their
        # `self` -- confirm against mobile_sam.utils.onnx.
        self.img_size = sam.image_encoder.img_size
        self.return_single_mask = return_single_mask

    def _embed_masks(self, input_mask: torch.Tensor | None) -> torch.Tensor:
        """
        Lifted from segment_anything.utils.onnx.SamOnnxModel

        Modified to remove ops based on whether input_mask is set.

        Parameters
        ----------
            input_mask: torch.Tensor | None
                Mask prompt. When given, it is passed through the prompt
                encoder's mask downscaling. When None, the prompt encoder's
                "no mask" embedding is broadcast over the embedding grid.

        Returns
        -------
            The dense prompt embedding.
        """
        prompt_encoder = self.model.prompt_encoder
        if input_mask is not None:
            return prompt_encoder.mask_downscaling(input_mask)

        no_mask_embed = prompt_encoder.no_mask_embed.weight
        # Allocate on the weight's device / dtype so the sum neither fails
        # (device mismatch) nor silently changes precision (dtype promotion)
        # when the model has been moved or cast.
        grid = torch.zeros(
            1,
            1,
            *self.embed_size,
            dtype=no_mask_embed.dtype,
            device=no_mask_embed.device,
        )
        return grid + no_mask_embed.reshape(1, -1, 1, 1)

    def forward(
        self,
        image_embeddings: torch.Tensor,
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
        mask_input: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run SAM lightweight decoder and return generated mask for given points

        Parameters
        ----------
            image_embeddings: torch.Tensor of shape [1, emb_dim, emb_size, emb_size]
                Image embeddings generated by Encoder
            point_coords: torch.Tensor of shape [1, k, 2]
                Point coordinates from input image for segmentation
            point_labels: torch.Tensor of shape [1, k]
                Point Labels to select/de-select given point for segmentation
                e.g. Corresponding value is 1 if this point is to be included, otherwise 0
            mask_input: torch.Tensor of shape [1, 1, 4 * emb_size, 4 * emb_size]
                Input mask to consider for segmentation. If using point based segmentation, this is unused.

        Returns
        -------
            masks: torch.Tensor
                Predicted low-resolution masks (no resizing is applied here).
            scores: torch.Tensor
                Predicted quality score for each returned mask.

            NOTE(review): the exact output shapes are set by
            sam_decoder_predict_masks and, when return_single_mask is set,
            by SamOnnxModel.select_masks. Presumably [1, num_masks, 256, 256]
            and [1, num_masks], where num_masks is a property of the mask
            decoder rather than of k -- confirm.

        Where,
            k = number of points
        """
        # SamOnnxModel helpers are called unbound, with this object standing
        # in for the SamOnnxModel instance.
        sparse_embedding = SamOnnxModel._embed_points(self, point_coords, point_labels)
        dense_embedding = self._embed_masks(mask_input)

        masks, scores = sam_decoder_predict_masks(
            self.model.mask_decoder,
            image_embeddings=image_embeddings,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embedding,
            dense_prompt_embeddings=dense_embedding,
        )

        if self.return_single_mask:
            masks, scores = SamOnnxModel.select_masks(
                self, masks, scores, point_coords.shape[1]
            )

        return masks, scores

    def _get_input_spec_for_instance(
        self: MobileSAMDecoder,
        has_mask_input: bool = False,
        num_of_points: int = 2,
    ) -> InputSpec:
        """
        Override for model.get_input_spec() when called on instances of this class.

        The initializer for BaseModel will automatically override get_input_spec
        with this function when the class is instantiated.

        The embedding dimension and grid size are taken from this instance's
        prompt encoder instead of the static defaults.
        """
        return self.__class__.get_input_spec(
            has_mask_input,
            num_of_points,
            self.model.prompt_encoder.embed_dim,
            self.embed_size[0],
            self.embed_size[1],
        )

    @staticmethod
    def get_input_spec(
        has_mask_input: bool = False,
        num_of_points: int = 2,
        embed_dim: int = 256,
        image_embedding_height: int = 64,
        image_embedding_width: int = 64,
    ) -> InputSpec:
        """
        Get the input specification ordered (name -> (shape, type)) pairs for this model.

        This can be used with the qai_hub python API to declare
        the model input specification upon submitting a profile job.

        Parameters
        ----------
            has_mask_input: bool
                Whether to include the optional "mask_input" entry.
            num_of_points: int
                Number of prompt points (k in forward()).
            embed_dim: int
                Channel count of the image embeddings.
            image_embedding_height: int
            image_embedding_width: int
                Spatial size of the image embeddings; the mask input is 4x
                larger along each axis.

        Returns
        -------
            Spec whose entries correspond one-to-one, in order, with the
            parameters of forward().
        """
        embed_size = (image_embedding_height, image_embedding_width)
        mask_input_size = tuple(4 * x for x in embed_size)

        input_spec: InputSpec = {
            "image_embeddings": ((1, embed_dim, *embed_size), "float32"),
            "point_coords": ((1, num_of_points, 2), "float32"),
            "point_labels": ((1, num_of_points), "float32"),
        }
        if has_mask_input:
            # Only "mask_input" is declared: forward() has no "has_mask_input"
            # parameter (the presence of the mask is decided by passing it or
            # not), so any further entry would not map to a model input.
            input_spec["mask_input"] = ((1, 1, *mask_input_size), "float32")
        return input_spec

    @staticmethod
    def get_channel_last_inputs(has_mask_input: bool = False) -> list[str]:
        """Return the names of the inputs declared as channel-last."""
        out = ["image_embeddings"]
        if has_mask_input:
            out.append("mask_input")
        return out

    @staticmethod
    def get_channel_last_outputs() -> list[str]:
        """Return the names of the outputs declared as channel-last."""
        return ["masks"]

    @staticmethod
    def get_output_names() -> list[str]:
        """Return the names of the outputs of forward(), in order."""
        return ["masks", "scores"]

    @classmethod
    def from_pretrained(cls, model_type: str = DEFAULT_MODEL_TYPE) -> MobileSAMDecoder:
        """Load the pretrained, patched network and return its decoder in single-mask mode."""
        return MobileSAMLoader.load(model_type, True)[2]


class MobileSAMLoader:
    """Builds the MobileSAM network together with its encoder / decoder wrappers."""

    @classmethod
    def load(
        cls,
        model_type: str = DEFAULT_MODEL_TYPE,
        single_mask_mode: bool = True,
    ) -> tuple[Sam, MobileSAMEncoder, MobileSAMDecoder]:
        """
        Load the pretrained network, patch it, and wrap it.

        Parameters
        ----------
            model_type: str
                Key into MODEL_REGISTRY selecting the checkpoint.
            single_mask_mode: bool
                Forwarded to MobileSAMDecoder as ``return_single_mask``.

        Returns
        -------
            (sam, encoder, decoder), where both wrappers share the same
            patched ``sam`` network.
        """
        network = cls._load_sam_from_repo(model_type)
        cls._patch_mobilesam_for_qnn_comatibility(network)
        encoder = MobileSAMEncoder(network)
        decoder = MobileSAMDecoder(network, single_mask_mode)
        return network, encoder, decoder

    @staticmethod
    def _load_sam_from_repo(model_type: str = DEFAULT_MODEL_TYPE) -> Sam:
        """Fetch the checkpoint registered for ``model_type`` and build the (unpatched) network from it."""
        checkpoint_name = MODEL_REGISTRY[model_type]
        checkpoint = CachedWebModelAsset.from_asset_store(
            MODEL_ID, MODEL_ASSET_VERSION, checkpoint_name
        )
        checkpoint.fetch()
        build_sam = sam_model_registry[model_type]
        return build_sam(checkpoint.path())

    @staticmethod
    def _patch_mobilesam_for_qnn_comatibility(sam: Sam) -> None:
        """
        Apply a patch to the SAM class for compatibility with QNN.

        The network is modified in place:
          * pixel_mean / pixel_std are divided by 255,
          * the mask decoder's predict_masks is rebound to sam_decoder_predict_masks,
          * the decoder's MLPs and attention modules are replaced by their
            wrapper counterparts from model_patches.
        """
        # Rescale the normalization constants from [0-255] to [0, 1], so the
        # network takes floating point ([0, 1]) input instead of int input.
        sam.pixel_mean = sam.pixel_mean / 255.0
        sam.pixel_std = sam.pixel_std / 255.0

        # Everything below patches the graph for QNN: each change either
        # speeds up QNN inference or fixes a failure seen when compiling to QNN.
        mask_decoder = sam.mask_decoder
        mask_decoder.predict_masks = functools.partial(
            sam_decoder_predict_masks, mask_decoder
        )

        hypernetwork_mlps = mask_decoder.output_hypernetworks_mlps
        # Iterate over a snapshot because the container is updated in place.
        for idx, original_mlp in enumerate(list(hypernetwork_mlps)):
            hypernetwork_mlps[idx] = Conv2DInplaceLinearSAMMaskDecoderMLP(
                cast(SAMMaskDecoderMLP, original_mlp)
            )
        mask_decoder.iou_prediction_head = Conv2DInplaceLinearSAMMaskDecoderMLP(
            mask_decoder.iou_prediction_head
        )

        two_way = cast(TwoWayTransformer, mask_decoder.transformer)
        two_way.final_attn_token_to_image = SplitHeadSAMDecoderAttention(
            two_way.final_attn_token_to_image
        )
        attention_attrs = (
            "self_attn",
            "cross_attn_token_to_image",
            "cross_attn_image_to_token",
        )
        for layer in two_way.layers:
            layer = cast(TwoWayAttentionBlock, layer)
            for attr in attention_attrs:
                setattr(
                    layer, attr, SplitHeadSAMDecoderAttention(getattr(layer, attr))
                )
            layer.mlp = Conv2DInplaceLinearSAMTransformerMLPBlock(layer.mlp)


@CollectionModel.add_component(MobileSAMEncoder)
@CollectionModel.add_component(MobileSAMDecoder)
class MobileSAM(CollectionModel):
    """Collection of the MobileSAM encoder and decoder, plus the network they share."""

    def __init__(self, sam: Sam, encoder: MobileSAMEncoder, decoder: MobileSAMDecoder):
        """
        Parameters
        ----------
            sam: Sam
                The underlying network.
            encoder: MobileSAMEncoder
            decoder: MobileSAMDecoder
                Component models registered with the collection.
        """
        super().__init__(encoder, decoder)
        self.sam, self.encoder, self.decoder = sam, encoder, decoder

    @classmethod
    def from_pretrained(
        cls, model_type: str = DEFAULT_MODEL_TYPE, single_mask_mode: bool = True
    ) -> MobileSAM:
        """Build the collection from the pieces returned by MobileSAMLoader.load()."""
        sam, encoder, decoder = MobileSAMLoader.load(model_type, single_mask_mode)
        return cls(sam, encoder, decoder)
